import os
import numpy as np
from glob import glob
from compute_recall import compute_recall_from_file
from scripts.align_ply_from_ape_log import align_ply_from_ape_log
import socket
def _parse_ape_log(log_path):
    """Parse the ATE RMSE and scale correction from an ``evo_ape -vas`` log.

    The last line containing ``rmse`` is split on tabs and the last line
    containing ``Scale correction`` is split on single spaces; in both cases
    the final field is parsed as a float.

    Args:
        log_path: Path of the text file the ``evo_ape`` stdout was redirected to.

    Returns:
        ``(ate_rmse, scale)`` as floats, or ``None`` if the file is missing,
        lacks either line, or the value cannot be parsed.
    """
    try:
        with open(log_path) as fh:
            lines = fh.readlines()
        rmse_lines = [line for line in lines if 'rmse' in line]
        scale_lines = [line for line in lines if 'Scale correction' in line]
        ate = float(rmse_lines[-1].split('\t')[-1])
        scale = float(scale_lines[-1].split(' ')[-1])
    except (OSError, IndexError, ValueError):
        return None
    return ate, scale


def _latex_row(label, values):
    """Build ``label & v1 & ... & vN & mean \\\\`` with two-decimal cells.

    The mean is NaN (rendered ``nan``) when *values* is empty, instead of
    letting ``np.mean`` warn on an empty list.
    """
    cells = [f"{v:.2f}" for v in values]
    mean = np.mean(values) if values else float("nan")
    cells.append(f"{mean:.2f}")
    return f"{label} & " + " & ".join(cells) + " \\\\"


if __name__ == "__main__":
    hostname = socket.gethostname()

    stride = 10
    IMU_poseinit_after = 20

    output_folder = "/home/zihzhu/data/output_rpng/code_release_test10"
    config_file = "config/rpng.yaml"
    # traj_name = 'traj_full_afterBA.txt'
    # traj_name = 'traj_full_beforeBA.txt'
    traj_name = 'traj_kf_beforeBA.txt'
    # traj_name = 'traj_kf_afterBA.txt'
    traj_stem = traj_name.split(".")[0]

    # Dataset root depends on the machine: local workstation vs. cluster.
    if hostname == 'zihzhu':
        seqs = sorted(glob('/home/zihzhu/data/Datasets/rpngar/*'))
    else:
        seqs = sorted(glob('/cluster/project/cvg/zihzhu/Datasets/rpngar/*'))
    print(output_folder)
    os.makedirs(os.path.join(output_folder, 'pcs'), exist_ok=True)

    ate_values = []
    scale_errors = []
    recalls = []
    for seq in seqs:
        # Skip auxiliary entries in the dataset folder that are not sequences.
        if 'TimeStamps' in seq or 'stride' in seq:
            continue

        name = os.path.basename(seq)
        os.makedirs(os.path.join(output_folder, name), exist_ok=True)
        print("##################################  Processing: {}  ##################################".format(seq))

        cmd = f'python demo.py --calib calib/rpngar.txt  \
                --imagedir {seq}/rgb \
                --config {config_file} \
                --stride {stride} \
                --IMU_poseinit_after {IMU_poseinit_after} \
                --imufile {seq}/imu.txt --output {output_folder}/{name} \
                --undistort > {output_folder}/{name}/log.txt'
                # --gsmapping
        print(cmd)

        # Run SLAM only when the requested trajectory has not been produced yet.
        if not os.path.exists(f'{output_folder}/{name}/{traj_name}'):
            # To submit to SLURM instead of running locally, replace os.system(cmd) with:
            # gpu_mem, hours = 20, 4
            # os.system(f"sbatch  --time={hours}:00:00 -A ls_polle -n 1  --cpus-per-task=16 --mem-per-cpu=10G  --gpus=1 --gres=gpumem:{gpu_mem}g --output=slurms/{output_folder.split('/')[-1]}_{name}.out  --wrap '{cmd}'")
            os.system(cmd)

        # evo_ape is always re-run so the log/plot reflect the current trajectory.
        log_ape_file = f'{output_folder}/{name}/log_ape_{traj_stem}.txt'
        cmd = f'evo_ape tum -vas --no_warnings --plot_mode xy --save_plot {output_folder}/{name}/ape_se3_{traj_stem}.png --save_results {output_folder}/{name}/ape_results.zip \
            {seq}/gt_imu.txt {output_folder}/{name}/{traj_name} > {log_ape_file}'
        os.system(cmd)

        # A failed parse must not reuse the previous sequence's values (or hit an
        # unbound name on the first sequence): record NaN so the table shows it.
        parsed = _parse_ape_log(log_ape_file)
        if parsed is None:
            print(f'WARNING: could not parse ATE/scale from {log_ape_file}; recording NaN')
            ATE, scale = float('nan'), float('nan')
        else:
            ATE, scale = parsed
        ATE_cm = ATE * 100
        scale_err_percent = abs(1 - scale) * 100

        # Align the exported 3DGS point cloud using the alignment recorded in the
        # APE log of *this* trajectory (previously a hardcoded kf log was used).
        if 'beforeBA' in traj_name:
            ply_in = f'{output_folder}/{name}/3dgs_before_final.ply'
            ply_out = f'{output_folder}/{name}/3dgs_before_final_aligned.ply'
        elif 'afterBA' in traj_name:
            ply_in = f'{output_folder}/{name}/3dgs_final.ply'
            ply_out = f'{output_folder}/{name}/3dgs_final_aligned.ply'
        else:
            ply_in = ply_out = None
        if ply_in is not None:
            try:
                align_ply_from_ape_log(
                    ape_log_path=log_ape_file,
                    ply_in=ply_in,
                    ply_out=ply_out,
                )
            except Exception as e:
                # Best-effort: a missing/bad .ply must not abort the evaluation sweep.
                print(e)

        _, __, recall = compute_recall_from_file(f'{seq}/gt_imu.txt', f'{output_folder}/{name}/{traj_name}')
        print(f'Recall: {recall:.2f}%')

        ate_values.append(ATE_cm)
        scale_errors.append(scale_err_percent)
        recalls.append(recall)
        print(f'APE before: {ATE:.4f}, scale: {scale:.4f}, Rel. scale error: {abs(1-scale):.4f}')

    # Spreadsheet-friendly lines: each value prefixed by ' ,'.
    print(''.join(' ,' + str(v) for v in ate_values))
    print(''.join(' ,' + str(v) for v in scale_errors))
    print(''.join(' ,' + str(v) for v in recalls))

    # LaTeX table rows: per-sequence values followed by the mean.
    print(_latex_row("ATE [cm]", ate_values))
    print(_latex_row("Scale error [\\%]", scale_errors))
    print(_latex_row("Ours Recall [\\%]", recalls))